We report a number of developing ideas on the Anthropic interpretability team, which might be of interest to researchers working actively in this space. Some of these are emerging strands of research that we expect to publish more on in the coming months. Others are minor points we wish to share, since we're unlikely to ever write a paper about them.
We'd ask you to treat these results like those of a colleague sharing some thoughts or preliminary experiments for a few minutes at a lab meeting, rather than a mature paper.
Tom Conerly, Hoagy Cunningham, Adly Templeton, Jack Lindsey, Basil Hosmer, and Adam Jermyn
An earlier version of this page incorrectly wrote the initialization as $U(-\frac{1}{n}, \frac{1}{n})$ instead of $U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}})$.
Since our last publication, we’ve made some improvements to how we train sparse autoencoders and crosscoders. While we haven’t extensively ablated all the decisions here, we wanted to share a description of our setup in the hope that it will be a useful starting point for external groups training sparse autoencoders. Our setup uses techniques from Rajamanoharan et al (2024).
Let $n$ be the input dimension, $o$ the output dimension, and $m$ the autoencoder hidden layer dimension. Let $s$ be the size of the dataset. Given encoder weights $W_e \in \mathbb{R}^{m \times n}$, decoder weights $W_d \in \mathbb{R}^{o \times m}$, log thresholds $t \in \mathbb{R}^{m}$, biases $b_e \in \mathbb{R}^{m}, b_d \in \mathbb{R}^{o}$, and hyperparameters $w$, $\lambda_S$, $\lambda_P$, $\varepsilon$, and $c$, the operations and loss function over a dataset $X \in \mathbb{R}^{s \times n}, Y \in \mathbb{R}^{s \times o}$ with datapoints $x \in \mathbb{R}^{n}, y \in \mathbb{R}^{o}$ are as described below.
Our implementation of JumpReLU uses a straight-through estimator of the gradient through the discontinuity of the nonlinearity, as in Rajamanoharan et al. (2024), but unlike Rajamanoharan et al. we allow the gradient to flow through the straight-through estimator to all model parameters, not just the JumpReLU thresholds. Also note that we use a tanh penalty to encourage sparsity rather than the penalty introduced by Rajamanoharan et al.
$\mathcal{L}_P$, which we call the pre-act loss, applies a small penalty to features which don't fire. We've found this extremely helpful in reducing dead features. Note that this provides a gradient signal whenever a feature is inactive, so the appropriate scale is lower than that of the other loss terms by a factor of the typical feature activation density.
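For concreteness, the forward pass and loss terms can be sketched in numpy. This is an illustrative reconstruction from the descriptions in this section, not our training code: the exact functional forms are our reading of the prose, and the hyperparameters $w$ and $\varepsilon$ do not appear because the straight-through estimator only matters for gradients, which a forward-only sketch does not model.

```python
import numpy as np

def sae_losses(x, y, W_e, W_d, b_e, b_d, t, lam_S=10.0, lam_P=3e-6, c=4.0):
    """Forward pass and loss for a JumpReLU autoencoder (illustrative sketch).

    Shapes for one datapoint: x (n,), y (o,), W_e (m, n), W_d (o, m),
    b_e (m,), b_d (o,), log thresholds t (m,). Loss forms are reconstructed
    from the surrounding prose and may differ in detail from the real setup.
    """
    pre = W_e @ x + b_e                        # pre-activations, shape (m,)
    f = np.where(pre > np.exp(t), pre, 0.0)    # JumpReLU with thresholds exp(t)
    y_hat = W_d @ f + b_d                      # reconstruction, shape (o,)

    recon = np.sum((y - y_hat) ** 2)           # reconstruction loss
    dec_norms = np.linalg.norm(W_d, axis=0)    # per-feature decoder column norms
    L_S = np.sum(np.tanh(c * f * dec_norms))   # tanh sparsity penalty
    # Pre-act loss: small penalty on inactive features, pushing their
    # pre-activations up toward their thresholds to reduce dead features.
    L_P = np.sum(np.maximum(np.exp(t) - pre, 0.0))
    return recon + lam_S * L_S + lam_P * L_P, f
```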
We use $c=4$, $\varepsilon=2$, $\lambda_P=3\times10^{-6}$, and values of $\lambda_S$ around 10. $b_d$ is initialized to all zeros. $t$ is initialized to $0.1$.
We initialize $W_d$ from $U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}})$. If $X=Y$ we initialize $W_e = \frac{n}{m}W_d^T$. If $X \ne Y$, we initialize $W_e$ from $U(-\frac{1}{\sqrt{m}}, \frac{1}{\sqrt{m}})$.
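A sketch of this initialization, assuming $W_d$ stores one decoder column per feature (the tied branch corresponds to the autoencoder case $X = Y$, where $o = n$):

```python
import numpy as np

def init_weights(n, m, o, tied, rng):
    """Initialize decoder and encoder weights as described above.

    W_d is (o, m) with one column per feature; W_e is (m, n).
    `tied` corresponds to the X == Y (autoencoder) case.
    """
    W_d = rng.uniform(-1 / np.sqrt(n), 1 / np.sqrt(n), size=(o, m))
    if tied:
        # X == Y implies o == n, so W_d.T has the right (m, n) shape.
        W_e = (n / m) * W_d.T
    else:
        W_e = rng.uniform(-1 / np.sqrt(m), 1 / np.sqrt(m), size=(m, n))
    return W_e, W_d
```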
We initialize $b_e$ by examining a subset of the data and picking a constant per feature such that each feature activates $\frac{10000}{m}$ of the time. In aggregate, roughly 10,000 features will fire per datapoint. We think this initialization is important for avoiding dead features.
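One way to realize this, sketched below, is to set each $b_{e,i}$ from an empirical quantile of that feature's raw pre-activations on a data subset. Only the target density is specified above; the quantile-based construction is an assumption.

```python
import numpy as np

def init_b_e(X_sub, W_e, t):
    """Pick a per-feature constant b_e so each feature fires ~10000/m of the time.

    X_sub: (s, n) subset of the dataset. Feature i fires when
    (W_e x + b_e)_i > exp(t_i), so we shift by the (1 - p) quantile of the
    raw pre-activations (quantile construction is our interpretation).
    """
    m = W_e.shape[0]
    p = min(10000.0 / m, 1.0)              # target activation density
    raw = X_sub @ W_e.T                    # (s, m) pre-activations without bias
    q = np.quantile(raw, 1.0 - p, axis=0)  # per-feature (1 - p) quantile
    return np.exp(t) - q                   # so a fraction ~p exceed the threshold
```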
The rows of the dataset $X$ are shuffled. The dataset is scaled by a single constant such that $\mathbb{E}_{x \in X}[||x||_2] = \sqrt{n}$. The goal of this change is for the same value of $\lambda_S$ to mean the same thing across datasets generated by transformers of different sizes.
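A minimal sketch of this scaling:

```python
import numpy as np

def scale_dataset(X):
    """Scale X (s, n) by one constant so that E[||x||_2] = sqrt(n)."""
    n = X.shape[1]
    scale = np.sqrt(n) / np.linalg.norm(X, axis=1).mean()
    return X * scale
```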
During training we use the Adam optimizer with $\beta_1=0.9$, $\beta_2=0.999$, and no weight decay. Our learning rate varies based on scaling laws, but 2e-4 is a reasonable default. The learning rate is decayed linearly to zero over the last 20% of training. We vary the number of training steps based on scaling laws. We use a batch size of 32,768, which we believe to be under the critical batch size. The gradient norm is clipped to 1 (using clip_grad_norm). We warm up $\lambda_S$ linearly over the entire duration of training: it is initially 0 and increases linearly to its final value. A reasonable default for $\lambda_S$ is 20 given our other parameter settings.
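The two schedules can be sketched as piecewise-linear functions of the step count, with the defaults quoted above:

```python
def lambda_s_schedule(step, total_steps, final_lambda_s=20.0):
    """lambda_S warms up linearly from 0 to its final value over all of training."""
    return final_lambda_s * step / total_steps

def lr_schedule(step, total_steps, base_lr=2e-4):
    """Constant learning rate, decayed linearly to zero over the last 20% of steps."""
    decay_start = 0.8 * total_steps
    if step < decay_start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - decay_start)
```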
Conceptually, a feature's activation is now $\mathbf{f}_i ||W_{d,i}||_2$ instead of $\mathbf{f}_i$. To simplify our analysis code we construct a model which makes identical predictions but has an L2 norm of 1 on the columns of $W_d$. We do this by setting $W_e' = W_e ||W_d||_2$, $b_e' = b_e ||W_d||_2$, $W_d' = \frac{W_d}{||W_d||_2}$, and $b_d' = b_d$ (the log thresholds shift correspondingly, $t' = t + \log ||W_d||_2$, so that the JumpReLU fires on the same datapoints).
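A sketch of this rescaling, with each feature's log threshold shifted by $\log ||W_{d,i}||_2$ so that the JumpReLU firing pattern, and hence the predictions, are unchanged:

```python
import numpy as np

def normalize_decoder(W_e, b_e, W_d, b_d, t):
    """Rescale so decoder columns have unit L2 norm, preserving predictions.

    Feature i's pre-activation is scaled by ||W_d[:, i]||, so its log
    threshold must shift by log ||W_d[:, i]|| to fire on the same inputs.
    """
    norms = np.linalg.norm(W_d, axis=0)   # per-feature decoder column norms
    W_e2 = W_e * norms[:, None]
    b_e2 = b_e * norms
    t2 = t + np.log(norms)                # keep the JumpReLU firing pattern fixed
    W_d2 = W_d / norms
    b_d2 = b_d.copy()
    return W_e2, b_e2, W_d2, b_d2, t2
```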